
## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.context import SQLContext
from pyspark.sql.functions import *
from graphframes import *

## SPARK SESSION 

# Obtain (or reuse) a SparkContext and build the SparkSession on top of it.
sc = SparkContext.getOrCreate(conf=SparkConf())
spark = SparkSession(sc)
# `sc` is the same context object that `spark.sparkContext` returns.
sc.setLogLevel('WARN')
# Checkpoint directory for GraphFrame.connectedComponents at the end of the script.
sc.setCheckpointDir("splink_sandbox/temp_graphframes/")

# Load L2 key pairs (LALVOTERID, emmid) for each available year.

l2_14 = spark.read.parquet('data/l2/year=2014').select('LALVOTERID', 'emmid')
l2_16 = spark.read.parquet('data/l2/year=2016').select('LALVOTERID', 'emmid')
l2_20 = spark.read.parquet('data/l2/year=2020').select('LALVOTERID', 'emmid')

# DataFrame.union keeps duplicate rows (UNION ALL), so a pair present in more
# than one year would be repeated and would multiply every row of the emmid
# left joins below. Distinct pairs keep genuine one-emmid-to-many-voter-ID
# mappings while removing pure repeats.
# NOTE(review): whether pairs actually repeat across years isn't visible here.
l2_keys = l2_14.union(l2_16).union(l2_20).distinct()

# Load CoreLogic / FEC matches; the right-hand id (uid_r) has its last two
# characters removed.

cl_fec = (
    spark.read.parquet('data/cl_fec_matches/*/*/*/')
    .select('uid_l', 'uid_r')
    .withColumn('uid_r', regexp_replace('uid_r', '.{2}$', ''))
)

# Load CoreLogic / L2 matches. The left-hand id (uid_l) has its last two
# characters removed; the right-hand id (uid_r) is an L2 emmid that is
# translated to LALVOTERID through l2_keys.

cl_l2 = (
    spark.read.parquet('data/cl_l2_matches/*/*/*/')
    .select('uid_l', 'uid_r')
    .withColumn('uid_l', regexp_replace('uid_l', '.{2}$', ''))
)

print(f"There are {cl_l2.count()} CoreLogic/L2 rows before join...")

# Left join: an emmid with no match in l2_keys leaves uid_r null, and an emmid
# matching several l2_keys rows yields several output rows. A larger count
# after the join than before therefore signals fan-out.
cl_l2 = (
    cl_l2
    .join(l2_keys, cl_l2.uid_r == l2_keys.emmid, 'left')
    .select('uid_l', 'LALVOTERID')
    .withColumnRenamed('LALVOTERID', 'uid_r')
)

print(f"There are {cl_l2.count()} CoreLogic/L2 rows after join...")

# Load FEC / L2 matches. The right-hand id (uid_r) is an L2 emmid that is
# translated to LALVOTERID through l2_keys.

fec_l2 = spark.read.parquet('data/fec_l2_matches_all/*/*/*/').select('uid_l', 'uid_r')

print(f"There are {fec_l2.count()} FEC/L2 rows before join...")

# Left join: an emmid with no match in l2_keys leaves uid_r null, and an emmid
# matching several l2_keys rows yields several output rows. A larger count
# after the join than before therefore signals fan-out.
fec_l2 = (
    fec_l2
    .join(l2_keys, fec_l2.uid_r == l2_keys.emmid, 'left')
    .select('uid_l', 'LALVOTERID')
    .withColumnRenamed('LALVOTERID', 'uid_r')
)

print(f"There are {fec_l2.count()} FEC/L2 rows after join...")

# Load CoreLogic -- CoreLogic matches.

# Cross-site matches: ids are used as stored (noted upstream as already truncated).
cl_site = spark.read.parquet('data/cl_site_across/').select('uid_l', 'uid_r')

# Dedupe matches: drop the last two characters from both ids, left then right.
cl = spark.read.parquet('data/cl_dedupe_matches_all/*/*/').select('uid_l', 'uid_r')
for _id_col in ('uid_l', 'uid_r'):
    cl = cl.withColumn(_id_col, regexp_replace(_id_col, '.{2}$', ''))

# All edges for clustering input.
#
# Every source DataFrame has columns (uid_l, uid_r) in that order, so the
# positional unions line up. cl_l2 and fec_l2 come from left joins and can
# carry a null uid_r. Such a null must not become a vertex: distinct() would
# collapse all of them into a single `id = null` vertex in the output. The
# non-null side of such a row is still kept as a vertex, so the record gets its
# own component.

all_pairs = (
    cl_fec
    .union(cl_l2)
    .union(fec_l2)
    .union(cl_site)
    .union(cl)
    .withColumnRenamed('uid_l', 'src')
    .withColumnRenamed('uid_r', 'dst')
    # Re-scanned by both counts, the vertex build and connectedComponents;
    # caching avoids re-reading every parquet source and redoing the joins.
    .cache()
)

edgesDf = all_pairs.dropna(subset=['src', 'dst'])

print(f"There are {edgesDf.count()} edges...")

verticesDf = (
    all_pairs.select('src')
    .union(all_pairs.select('dst'))  # positional union: column stays named 'src'
    .dropna()
    .distinct()
    .withColumnRenamed('src', 'id')
)

print(f"There are {verticesDf.count()} unique vertices...")

g = GraphFrame(verticesDf, edgesDf)

out = g.connectedComponents().select("id", "component")

out.coalesce(100).write.mode("overwrite").parquet("data/all_components")
